{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\Anaconda3\\envs\\machine_learning\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using cuda:0\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x28ea8f1f9d0>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import random\n",
    "import h5py\n",
    "import copy\n",
    "import os\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "from torchvision.models import *\n",
    "from torchvision.transforms import *\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "from torch.optim import *\n",
    "from torch.optim.lr_scheduler import *\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from Utils.utils import *\n",
    "from Utils.dataset import *\n",
    "from Utils.model_loading import *\n",
    "\n",
    "# Prefer the first CUDA device when one is available; otherwise run on the CPU.\n",
    "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "print(f\"Using {device}\")\n",
    "\n",
    "# Seed NumPy, Python's `random`, and PyTorch with the same fixed value.\n",
    "# `torch.manual_seed` stays last so that its return value remains the cell output.\n",
    "np.random.seed(0)\n",
    "random.seed(0)\n",
    "torch.manual_seed(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training Different Models\n",
    "\n",
    "- IMPORTANT: Before training, replace the model's output layer with one sized for this dataset's classes (`ArrhythmiaLabels.size` outputs), or face suffering."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target side length handed to the Resize transform below.\n",
    "image_size = 224\n",
    "\n",
    "dataloader = build_dataloader(\n",
    "    train_path=\"Data/mitbih_mif_train_small.h5\",\n",
    "    test_path=\"Data/mitbih_mif_test.h5\",\n",
    "    batch_size=32,\n",
    "    transform=Resize((image_size, image_size), antialias=None),\n",
    ")\n",
    "\n",
    "# Uncomment to call the visualization helper on the training split:\n",
    "# visualize_ecg_data(dataloader[\"train\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finetuning\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1 Accuracy 97.36% / Best Accuracy: 97.36%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2 Accuracy 95.74% / Best Accuracy: 97.36%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3 Accuracy 97.03% / Best Accuracy: 97.36%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4 Accuracy 97.83% / Best Accuracy: 97.83%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 5 Accuracy 97.31% / Best Accuracy: 97.83%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 6 Accuracy 98.09% / Best Accuracy: 98.09%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 7 Accuracy 98.11% / Best Accuracy: 98.11%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 8 Accuracy 97.35% / Best Accuracy: 98.11%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 9 Accuracy 97.55% / Best Accuracy: 98.11%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10 Accuracy 97.35% / Best Accuracy: 98.11%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 11 Accuracy 98.31% / Best Accuracy: 98.31%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 12 Accuracy 98.61% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 13 Accuracy 96.06% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 14 Accuracy 98.14% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 15 Accuracy 98.31% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 16 Accuracy 98.31% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 17 Accuracy 98.39% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 18 Accuracy 98.43% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 19 Accuracy 98.33% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 20 Accuracy 98.53% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 21 Accuracy 98.20% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 22 Accuracy 98.39% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 23 Accuracy 98.41% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 24 Accuracy 98.31% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 25 Accuracy 97.94% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 26 Accuracy 96.18% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 27 Accuracy 98.40% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 28 Accuracy 98.10% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 29 Accuracy 98.59% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 30 Accuracy 98.55% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 31 Accuracy 98.54% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 32 Accuracy 98.47% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 33 Accuracy 98.41% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 34 Accuracy 98.41% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 35 Accuracy 98.41% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 36 Accuracy 98.39% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 37 Accuracy 98.33% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 38 Accuracy 98.30% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 39 Accuracy 97.40% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 40 Accuracy 98.40% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 41 Accuracy 98.55% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 42 Accuracy 98.47% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 43 Accuracy 98.42% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 44 Accuracy 98.49% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 45 Accuracy 98.55% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 46 Accuracy 98.55% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 47 Accuracy 98.53% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 48 Accuracy 98.49% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 49 Accuracy 98.56% / Best Accuracy: 98.61%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 50 Accuracy 98.63% / Best Accuracy: 98.63%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                       "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of Loaded Model: 98.63\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    }
   ],
   "source": [
    "model_name = \"vgg16_bn\"\n",
    "model = load_model(model_name)\n",
    "\n",
    "# Replace the last classifier layer so that the head has `num_classes` outputs.\n",
    "num_classes = ArrhythmiaLabels.size\n",
    "model.classifier[6] = nn.Linear(4096, num_classes)\n",
    "\n",
    "# Finetuning setup\n",
    "num_finetune_epochs = 50\n",
    "model.to(device)\n",
    "optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)\n",
    "# NOTE(review): T_max is given in epochs. If `train` steps the scheduler once per batch\n",
    "# (as `train_with_loss_record` further down in this notebook does), T_max should be\n",
    "# expressed in batches instead - confirm against Utils.\n",
    "scheduler = CosineAnnealingLR(optimizer, num_finetune_epochs)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "best_model_checkpoint = dict()\n",
    "best_accuracy = 0\n",
    "# When True, a full checkpoint (model, optimizer, scheduler, epoch) is written every 10 epochs.\n",
    "safety = True\n",
    "\n",
    "# Create the output directories up front. `exist_ok=True` replaces the\n",
    "# check-then-create pattern and keeps the directory handling out of the epoch loop.\n",
    "os.makedirs(\"Pretrained\", exist_ok=True)\n",
    "if safety:\n",
    "    os.makedirs(\"Pretrained/Safety\", exist_ok=True)\n",
    "\n",
    "print(\"Finetuning\")\n",
    "for epoch in range(num_finetune_epochs):\n",
    "    train(model, dataloader[\"train\"], criterion, optimizer, scheduler)\n",
    "    accuracy = evaluate(model, dataloader[\"test\"])\n",
    "\n",
    "    # Keep a deep copy of the best weights seen so far; later epochs keep\n",
    "    # modifying the live model.\n",
    "    # NOTE(review): the best epoch is selected on the \"test\" split, so the reported\n",
    "    # best accuracy is optimistic - consider a separate validation split.\n",
    "    if accuracy > best_accuracy:\n",
    "        best_model_checkpoint[\"state_dict\"] = copy.deepcopy(model.state_dict())\n",
    "        best_accuracy = accuracy\n",
    "\n",
    "    if safety and (epoch + 1) % 10 == 0:\n",
    "        checkpoint = {\n",
    "            \"model_state_dict\":     model.state_dict(),\n",
    "            \"optimizer_state_dict\": optimizer.state_dict(),\n",
    "            \"scheduler_state_dict\": scheduler.state_dict(),\n",
    "            # Stored 0-based, while the file name below uses the 1-based epoch number.\n",
    "            \"epoch\":                epoch,\n",
    "        }\n",
    "        torch.save(checkpoint, f\"Pretrained/Safety/check_{epoch+1}.pth\")\n",
    "\n",
    "    print(f\"Epoch {epoch+1} Accuracy {accuracy:.2f}% / Best Accuracy: {best_accuracy:.2f}%\")\n",
    "\n",
    "# Save best model\n",
    "save_path = f\"Pretrained/{model_name}_ecg_ep{num_finetune_epochs}_i{image_size}.pth\"\n",
    "torch.save(best_model_checkpoint[\"state_dict\"], save_path)\n",
    "\n",
    "# Test saved model: reload it from disk and evaluate it on the test split.\n",
    "model = load_model_from_pretrained(model_name, save_path, num_classes=num_classes)\n",
    "model.to(device)\n",
    "acc = evaluate(model, dataloader[\"test\"])\n",
    "print(f\"Accuracy of Loaded Model: {acc:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Resolution Scaling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_with_loss_record(\n",
    "    model:      nn.Module,\n",
    "    dataloader: DataLoader,\n",
    "    criterion:  nn.Module,\n",
    "    optimizer:  Optimizer,\n",
    "    scheduler:  LambdaLR\n",
    ") -> list:\n",
    "    \"\"\"Run one pass over `dataloader` in training mode and record every batch loss.\n",
    "\n",
    "    Each batch is moved to the notebook-level `device` before the forward pass.\n",
    "    Note that `scheduler.step()` is called right after every `optimizer.step()`,\n",
    "    i.e. once per batch rather than once per epoch.\n",
    "\n",
    "    Args:\n",
    "        model: Network to train; switched to training mode first.\n",
    "        dataloader: Iterable yielding `(inputs, labels)` batches.\n",
    "        criterion: Loss called as `criterion(outputs, labels)`.\n",
    "        optimizer: Optimizer whose `zero_grad()` and `step()` run for each batch.\n",
    "        scheduler: Learning-rate scheduler, stepped once per batch.\n",
    "\n",
    "    Returns:\n",
    "        The `loss.item()` value of every batch, in iteration order.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    batch_losses = []\n",
    "\n",
    "    progress = tqdm(dataloader, desc=\"Train\", leave=False)\n",
    "    for batch_inputs, batch_labels in progress:\n",
    "        batch_inputs, batch_labels = batch_inputs.to(device), batch_labels.to(device)\n",
    "\n",
    "        # Forward pass.\n",
    "        loss = criterion(model(batch_inputs), batch_labels)\n",
    "\n",
    "        # Backward pass: clear old gradients, backpropagate, then update weights and LR.\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        scheduler.step()\n",
    "\n",
    "        batch_losses.append(loss.item())\n",
    "\n",
    "    return batch_losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name  = \"mobilenet_v3_small\"\n",
    "num_classes = ArrhythmiaLabels.size\n",
    "num_finetune_epochs = 15\n",
    "resolutions = [80, 104, 128, 152, 176, 200, 224]\n",
    "# Per-resolution history: one accuracy per epoch, and one list of batch losses per epoch.\n",
    "running_accuracy = {r: [] for r in resolutions}\n",
    "running_loss = {r: [] for r in resolutions}\n",
    "\n",
    "# All checkpoints of this experiment go into one directory; create it once, up front.\n",
    "checkpoint_dir = \"Pretrained/MobileNetV3-Small\"\n",
    "os.makedirs(checkpoint_dir, exist_ok=True)\n",
    "\n",
    "for resolution in resolutions:\n",
    "    # Setup training: every resolution starts from a freshly loaded model with a new head.\n",
    "    model = load_model(model_name)\n",
    "    model.classifier[3] = nn.Linear(1024, num_classes)\n",
    "    model.to(device)\n",
    "\n",
    "    dataloader = build_dataloader(\n",
    "        train_path = \"Data/mitbih_mif_train_small.h5\",\n",
    "        test_path  = \"Data/mitbih_mif_test.h5\",\n",
    "        batch_size = 128,\n",
    "        transform  = Resize((resolution, resolution))\n",
    "    )\n",
    "\n",
    "    optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)\n",
    "    # `train_with_loss_record` steps the scheduler once per batch, so the cosine\n",
    "    # horizon has to be counted in batches. With T_max = num_finetune_epochs the\n",
    "    # learning rate would swing down and up every few batches instead of annealing\n",
    "    # once over the whole run.\n",
    "    # Assumes dataloader[\"train\"] supports len() (as a map-style DataLoader does) - TODO confirm.\n",
    "    total_train_steps = num_finetune_epochs * len(dataloader[\"train\"])\n",
    "    scheduler = CosineAnnealingLR(optimizer, T_max=total_train_steps)\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "    best_model_checkpoint = dict()\n",
    "    best_accuracy = 0\n",
    "\n",
    "    # Finetune\n",
    "    print(f\"Finetuning at resolution {resolution}\")\n",
    "    for epoch in range(num_finetune_epochs):\n",
    "        losses = train_with_loss_record(model, dataloader[\"train\"], criterion, optimizer, scheduler)\n",
    "        accuracy = evaluate(model, dataloader[\"test\"])\n",
    "\n",
    "        running_accuracy[resolution].append(accuracy)\n",
    "        running_loss[resolution].append(losses)\n",
    "\n",
    "        # Keep a deep copy of the best weights; later epochs keep modifying the live model.\n",
    "        if accuracy > best_accuracy:\n",
    "            best_model_checkpoint[\"state_dict\"] = copy.deepcopy(model.state_dict())\n",
    "            best_accuracy = accuracy\n",
    "\n",
    "        print(f\"Epoch {epoch+1} Accuracy {accuracy:.2f}% / Best Accuracy: {best_accuracy:.2f}%\")\n",
    "\n",
    "    # Save best model\n",
    "    save_path = f\"{checkpoint_dir}/{model_name}_ecg_ep{num_finetune_epochs}_i{resolution}.pth\"\n",
    "    torch.save(best_model_checkpoint[\"state_dict\"], save_path)\n",
    "\n",
    "    # Write the running accuracy and loss to another file for safety.\n",
    "    # Both files are opened in append mode: one line per resolution, and\n",
    "    # re-running this cell adds further lines rather than overwriting.\n",
    "    with open(\"running_acc.txt\", \"a\") as file:\n",
    "        file.write(\" \".join(map(str, running_accuracy[resolution])))\n",
    "        file.write(\"\\n\")\n",
    "\n",
    "    # Flatten the per-epoch loss lists into one sequence of batch losses.\n",
    "    flat_losses = [loss for epoch_losses in running_loss[resolution] for loss in epoch_losses]\n",
    "    with open(\"running_loss.txt\", \"a\") as file:\n",
    "        file.write(\" \".join(map(str, flat_losses)))\n",
    "        file.write(\"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# OFA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the `ofa_specialized_get` entry point from the once-for-all repository via torch.hub.\n",
    "ofa_specialized_get = torch.hub.load(\n",
    "    \"mit-han-lab/once-for-all\",\n",
    "    \"ofa_specialized_get\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `ofa_specialized_get` also returns an image size for the chosen network. It is kept\n",
    "# under its own name so that it no longer overwrites `image_size`: the data below is\n",
    "# resized to 152x152, and `image_size` is later embedded in the checkpoint file name,\n",
    "# so it has to describe the resolution the model is actually finetuned at.\n",
    "model, ofa_image_size = ofa_specialized_get(\"pixel1_lat@20ms_top1@71.4_finetune@25\", pretrained=True)\n",
    "image_size = 152\n",
    "\n",
    "dataloader = build_dataloader(\n",
    "    train_path = \"Data/mitbih_mif_train_small.h5\",\n",
    "    test_path  = \"Data/mitbih_mif_test.h5\",\n",
    "    batch_size = 128,\n",
    "    transform  = Resize((image_size, image_size))\n",
    ")\n",
    "\n",
    "# Replace the classification head with a 5-way linear layer.\n",
    "# NOTE(review): 5 presumably equals ArrhythmiaLabels.size, and 1280 the feature width of\n",
    "# this particular OFA network - confirm both.\n",
    "model.classifier.linear = nn.Linear(1280, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name  = \"ofa_pixel1_20\"\n",
    "num_classes = ArrhythmiaLabels.size\n",
    "\n",
    "# Create the output directory before training, so that a filesystem problem shows up\n",
    "# before the long run rather than after it. `exist_ok=True` replaces the\n",
    "# check-then-create pattern.\n",
    "os.makedirs(\"Pretrained\", exist_ok=True)\n",
    "\n",
    "# Finetuning setup\n",
    "num_finetune_epochs = 50\n",
    "model.to(device)\n",
    "optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)\n",
    "# NOTE(review): T_max is given in epochs. If `train` steps the scheduler once per batch\n",
    "# (as `train_with_loss_record` in this notebook does), T_max should be expressed in\n",
    "# batches instead - confirm against Utils.\n",
    "scheduler = CosineAnnealingLR(optimizer, num_finetune_epochs)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "best_model_checkpoint = dict()\n",
    "best_accuracy = 0\n",
    "\n",
    "print(\"Finetuning\")\n",
    "for epoch in range(num_finetune_epochs):\n",
    "    train(model, dataloader[\"train\"], criterion, optimizer, scheduler)\n",
    "    accuracy = evaluate(model, dataloader[\"test\"])\n",
    "\n",
    "    # Keep a deep copy of the best weights; later epochs keep modifying the live model.\n",
    "    if accuracy > best_accuracy:\n",
    "        best_model_checkpoint[\"state_dict\"] = copy.deepcopy(model.state_dict())\n",
    "        best_accuracy = accuracy\n",
    "\n",
    "    print(f\"Epoch {epoch+1} Accuracy {accuracy:.2f}% / Best Accuracy: {best_accuracy:.2f}%\")\n",
    "\n",
    "# Save best model\n",
    "save_path = f\"Pretrained/{model_name}_ecg_ep{num_finetune_epochs}_i{image_size}.pth\"\n",
    "torch.save(best_model_checkpoint[\"state_dict\"], save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test saved model\n",
    "# NOTE(review): this rebuilds the \"flops@595M\" network and loads an \"ofa_595M\" checkpoint,\n",
    "# while the cells above finetune and save \"ofa_pixel1_20\" - confirm which checkpoint is\n",
    "# meant to be evaluated here.\n",
    "model, _ = ofa_specialized_get(\"flops@595M_top1@80.0_finetune@75\", pretrained=True)\n",
    "model.classifier.linear = nn.Linear(1536, 5)\n",
    "\n",
    "# Forward slash as separator: the previous backslash form only worked on Windows and was\n",
    "# an invalid escape sequence inside the string literal.\n",
    "checkpoint_path = \"Pretrained/ofa_595M_ecg_ep50_i152.pth\"\n",
    "# `map_location=device` lets a checkpoint saved from the GPU load on a CPU-only machine.\n",
    "model.load_state_dict(torch.load(checkpoint_path, map_location=device))\n",
    "model.to(device)\n",
    "acc = evaluate(model, dataloader[\"test\"])\n",
    "print(f\"Accuracy of Loaded Model: {acc:.2f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "machine_learning",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
